#
# Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
#

"""
compare_db_mysql test.
"""

import os
import sys

import mutlib

from mysql.utilities.common.tools import get_tool_path
from mysql.utilities.exception import MUTLibError, UtilError


# Scratch file shared by the test: SQL statements are written to it and fed
# to the mysql client, and mysqldbcompare's diff SQL output is redirected
# into it before being replayed on server1.
_INPUT_SQL_FILE = 'test_data.sql'


class test(mutlib.System_test):
    """Test mysqldbcompare --difftype=sql generation from mysql client.

    This module tests the diff SQL generated by the mysqldbcompare tool when
    data is inserted into the database using the mysql client.
    """

    server1 = None
    server2 = None
    mysql_path = None

    def check_prerequisites(self):
        # Need at least one base server
        return self.check_num_servers(1)

    def setup(self):
        self.res_fname = "result.txt"

        # Get path to mysql client tool from base server
        base_server = self.servers.get_server(0)
        if not base_server:
            raise MUTLibError("Unable to get base server.")
        rows = base_server.exec_query("SHOW VARIABLES LIKE 'basedir'")
        if not rows:
            raise MUTLibError("Unable to determine 'basedir' for base server.")
        basedir = rows[0][1]

        try:
            self.mysql_path = get_tool_path(basedir, "mysql", quote=True)
        except UtilError as err:
            raise MUTLibError("Unable to find mysql client tool for server "
                              "{0}@{1} on basedir={2}. "
                              "ERROR: {3}".format(base_server.host,
                                                  base_server.port, basedir,
                                                  err.errmsg))

        # Make sure at least 3 servers are available (server 0 is the base
        # server; servers 1 and 2 are the ones compared by the test).
        if self.servers.num_servers() < 3:
            try:
                self.servers.spawn_new_servers(3)
            except MUTLibError as err:
                raise MUTLibError(
                    "Cannot spawn needed servers: {0}".format(err.errmsg))

        # Set spawned servers
        self.server1 = self.servers.get_server(1)
        self.server2 = self.servers.get_server(2)

        # Create the same database and table on both servers.
        common_data = (
            "CREATE DATABASE util_test_mysql_client;\n"
            "CREATE TABLE util_test_mysql_client.t1 (a char(30), "
            "PRIMARY KEY (a)) ENGINE=InnoDB DEFAULT CHARSET=latin1;\n"
        )
        self._exec_sql_data(self.server1, common_data)
        self._exec_sql_data(self.server2, common_data)

        return True

    def exec_mysql_client_cmd(self, server, input_file):
        """Execute statements in input file with mysql client.

        Execute the statements in the given input file on the specified server
        using the mysql client tool (always connecting as root).

        server[in]      target server to execute the SQL statements.
        input_file[in]  input file with the SQL statements to execute.

        Returns the result of the execution (the value of exec_util()).
        """
        # conn_val[1], [2], [3] and [4] are used below as the password,
        # host, port and socket of the server, respectively.
        conn_val = self.get_connection_values(server)
        cmd_list = ["{0} -uroot".format(self.mysql_path)]
        if conn_val[1]:
            cmd_list.append("-p{0}".format(conn_val[1]))
        host = conn_val[2]
        if host:
            if host == "localhost":
                # Force a TCP connection instead of the default socket.
                host = "127.0.0.1"
            elif ']' in host:
                # Strip brackets from the host (presumably an IPv6 literal
                # such as [::1]), which the mysql client does not accept.
                host = host.replace('[', '').replace(']', '')
            cmd_list.append("-h {0}".format(host))
        if conn_val[3]:
            cmd_list.append("--port={0}".format(conn_val[3]))
        if conn_val[4]:
            cmd_list.append("--socket={0}".format(conn_val[4]))
        cmd_list.append("< {0}".format(input_file))
        cmd = " ".join(cmd_list)

        # Execute mysql client command
        return self.exec_util(cmd, self.res_fname, True)

    def _exec_sql_data(self, server, sql_data):
        """Write sql_data to the scratch SQL file and run it on server.

        server[in]      target server to execute the SQL statements.
        sql_data[in]    SQL text (may contain mysql client DELIMITER
                        commands) to write and execute.

        Returns the result of exec_mysql_client_cmd().
        """
        with open(_INPUT_SQL_FILE, 'w') as sql_file:
            sql_file.write(sql_data)
        return self.exec_mysql_client_cmd(server, _INPUT_SQL_FILE)

    def _append_sql_file_to_results(self):
        """Append each line of the scratch SQL file to the test results.

        The lines are also echoed to stdout when running in debug mode.
        """
        if self.debug:
            print("\nContents of SQL file:")
        with open(_INPUT_SQL_FILE, 'r') as sql_file:
            for line in sql_file:
                self.results.append(line)
                if self.debug:
                    sys.stdout.write(line)

    @staticmethod
    def _compare_cmd(srv1_con, srv2_con, suffix=""):
        """Build the mysqldbcompare command used by every test case.

        srv1_con[in]    connection string for server1.
        srv2_con[in]    connection string for server2.
        suffix[in]      text appended right after the final '-t' option
                        (e.g. an output redirection).

        Returns the command string.
        """
        return ("mysqldbcompare.py --server1={0} --server2={1} "
                "util_test_mysql_client:util_test_mysql_client "
                "--difftype=sql --changes-for=server1 "
                "-t{2}".format(srv1_con, srv2_con, suffix))

    @staticmethod
    def _procedure_sql(definer, with_comment=False):
        """Return mysql client input that creates procedure `p1`.

        definer[in]       DEFINER clause value, e.g. "`root`@`localhost`".
        with_comment[in]  if True, add a SQL comment line to the body so the
                          procedure definition differs from the plain one.
        """
        body_comment = "  -- test comment\n" if with_comment else ""
        return (
            "USE util_test_mysql_client;\n"
            "DELIMITER //\n"
            "CREATE DEFINER={0} PROCEDURE `p1`(OUT param1 INT)\n"
            "BEGIN\n"
            "{1}"
            "  SELECT COUNT(*) INTO param1 FROM mysql.user;\n"
            "END //\n"
            "DELIMITER ;\n"
        ).format(definer, body_comment)

    def _run_diff_test(self, test_num, description, srv1_con, srv2_con,
                       sql_data1, sql_data2):
        """Check that the diff SQL produced by mysqldbcompare syncs server1.

        Steps: run sql_data1 on server1 and sql_data2 on server2, capture the
        diff SQL (for server1) into the scratch file, record it in the
        results, replay it on server1 and compare the servers again.

        test_num[in]     number of the test case (used in the comments).
        description[in]  text describing the difference under test.
        srv1_con[in]     connection string for server1.
        srv2_con[in]     connection string for server2.
        sql_data1[in]    SQL executed on server1.
        sql_data2[in]    SQL executed on server2.

        Raises MUTLibError if any mysqldbcompare run does not succeed.
        """
        # Create the differing objects on each server.
        self._exec_sql_data(self.server1, sql_data1)
        self._exec_sql_data(self.server2, sql_data2)

        # Get diff SQL for server1 from server2. run_test_case is given 1
        # here (differences present), unlike the 0 used when none are
        # expected.
        comment = ("Test case {0}a - difference: {1} "
                   "(get diff SQL).".format(test_num, description))
        cmd_str = self._compare_cmd(srv1_con, srv2_con,
                                    " > {0}".format(_INPUT_SQL_FILE))
        if not self.run_test_case(1, cmd_str, comment):
            raise MUTLibError("{0}: failed".format(comment))

        # Append output difference for server1 to result file.
        self._append_sql_file_to_results()

        # Execute diff from mysqldbcompare on server1
        self.exec_mysql_client_cmd(self.server1, _INPUT_SQL_FILE)

        # Compare server1 and server2 - No difference expected now
        comment = ("Test case {0}b - difference: {1} "
                   "(compare).".format(test_num, description))
        cmd_str = self._compare_cmd(srv1_con, srv2_con)
        if not self.run_test_case(0, cmd_str, comment):
            raise MUTLibError("{0}: failed".format(comment))

    def run(self):
        srv1_con = self.build_connection_string(self.server1).strip(' ')
        srv2_con = self.build_connection_string(self.server2).strip(' ')

        # Baseline test - No differences
        test_num = 1
        comment = "Test case {0} - no differences.".format(test_num)
        cmd_str = self._compare_cmd(srv1_con, srv2_con, " ")
        if not self.run_test_case(0, cmd_str, comment):
            raise MUTLibError("{0}: failed".format(comment))

        # Different PROCEDURE using ';' in body.
        test_num += 1
        definer = "`root`@`localhost`"
        self._run_diff_test(test_num, "PROCEDURE with ';'",
                            srv1_con, srv2_con,
                            self._procedure_sql(definer),
                            self._procedure_sql(definer, with_comment=True))

        # DROP previous diff data
        self.server1.exec_query('DROP PROCEDURE `util_test_mysql_client`.`p1`')
        self.server2.exec_query('DROP PROCEDURE `util_test_mysql_client`.`p1`')

        # Different PROCEDURE using ';' and DEFINER with wildcard (%) in body.
        test_num += 1
        definer = "`root`@`%`"
        self._run_diff_test(test_num, "PROCEDURE DEFINER with wildcard (%)",
                            srv1_con, srv2_con,
                            self._procedure_sql(definer),
                            self._procedure_sql(definer, with_comment=True))

        return True

    def get_result(self):
        return self.compare(__name__, self.results)

    def record(self):
        return self.save_result_file(__name__, self.results)

    def cleanup(self):
        # Drop created databases. The servers are still None (class default)
        # if setup() failed before assigning them, so skip those.
        for server in (self.server1, self.server2):
            if server is not None:
                self.drop_db(server, 'util_test_mysql_client')

        # Remove auxiliary files (best effort: they may not exist).
        for fname in (self.res_fname, _INPUT_SQL_FILE):
            try:
                os.unlink(fname)
            except OSError:
                pass
        return True
